{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "import torch.utils.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # torch.nn.Conv2d(in_channels, out_channels, kernel_size,\n",
    "        # stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
    "        self.dropout1 = nn.Dropout2d(0.25)\n",
    "        self.dropout2 = nn.Dropout2d(0.5)\n",
    "        self.fc1 = nn.Linear(9216, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        output = F.log_softmax(x, dim=1)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    total = 0\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        total += len(data)\n",
    "        progress = math.ceil(batch_idx / len(train_loader) * 50)\n",
    "        print(\"\\rTrain epoch %d: %d/%d, [%-51s] %d%%\" %\n",
    "              (epoch, total, len(train_loader.dataset),\n",
    "               '-' * progress + '>', progress * 2), end='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    epochs = 2\n",
    "    batch_size = 64\n",
    "    torch.manual_seed(0)\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data/MNIST', train=True, download=False,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ])),\n",
    "        batch_size=batch_size, shuffle=True)\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data/MNIST', train=False, download=False, transform=transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.1307,), (0.3081,))\n",
    "        ])),\n",
    "        batch_size=1000, shuffle=True)\n",
    "\n",
    "    model = Net().to(device)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9)\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        train(model, device, train_loader, optimizer, epoch)\n",
    "        test(model, device, test_loader)\n",
    "\n",
    "#     torch.save(model.state_dict(), \"mnist_cnn.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%\n",
      "Test: average loss: 0.0644, accuracy: 9798/10000 (98%)\n",
      "Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%\n",
      "Test: average loss: 0.0453, accuracy: 9861/10000 (99%)\n"
     ]
    }
   ],
   "source": [
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
